import abc
import torch
import numpy as np
from torch.utils.data import Dataset
import matplotlib.pyplot as plt

class DiscreteToyDataset(Dataset, abc.ABC):
    """Base interface for the synthetic discrete toy datasets.

    Concrete subclasses act as their own (unbounded) iterator and yield
    ``(samples, cond)`` batches. Every hook below must be overridden; the base
    versions raise ``NotImplementedError`` so a missing override fails loudly
    instead of silently returning ``None`` (which would make ``len()`` raise a
    confusing ``TypeError`` and iteration yield ``None`` forever).
    """

    def __iter__(self):
        # The dataset is its own iterator; batches come from ``__next__``.
        return self

    def __next__(self):
        """Return the next ``(samples, cond)`` batch."""
        raise NotImplementedError

    def __len__(self):
        """Return the nominal length of the dataset."""
        raise NotImplementedError

    def generate_cond(self, n_samples):
        """Draw ``n_samples`` conditioning labels."""
        raise NotImplementedError

    def sample_from_condition(self, cond):
        """Draw one sample for each conditioning label in ``cond``."""
        raise NotImplementedError

    def plot_samples(self, samples, cond, path):
        """Visualise ``samples``; subclasses may use a different signature."""
        raise NotImplementedError
    
class VectorInputMixture(DiscreteToyDataset):
    """Mixture of categorical distributions over the positions of a 1-D vector.

    Each class ``k`` is given as a non-negative vector that is normalised into
    a categorical distribution ``p_k`` over ``length`` positions. A draw picks
    a class label from ``weights`` (the condition) and then one position from
    that class's ``p_k``.

    Args:
        batch_size: Number of samples produced by ``sample()`` / ``next()``.
        vectors: Sequence (or 2-D tensor) of equal-length 1-D tensors with
            non-negative entries; each is divided by its own sum.
        weights: Optional per-class mixture weights, normalised to sum to 1.
            Uniform when ``None``.

    Raises:
        ValueError: If ``vectors`` is empty, the vectors differ in length, or
            the number of weights differs from the number of vectors.
    """

    def __init__(self, batch_size, vectors, weights=None):
        self.batch_size = batch_size
        self.context_len = 1
        self.vectors = vectors
        self.n_classes = len(vectors)
        if self.n_classes == 0:
            raise ValueError("At least one vector is required")

        if weights is None:
            self.weights = torch.ones(self.n_classes).float() / self.n_classes
        else:
            weights_tensor = torch.tensor(weights, dtype=torch.float)
            self.weights = weights_tensor / weights_tensor.sum()
        # The zip() below would silently truncate on a mismatch, leaving the
        # mixture and generate_cond() inconsistent with the list of classes.
        if self.weights.numel() != self.n_classes:
            raise ValueError("Number of weights must match the number of vectors")

        lengths = [v.shape[0] for v in vectors]
        if len(set(lengths)) > 1:
            raise ValueError("All vectors must have the same length")

        self.length = lengths[0]
        # Samples are position indices in [0, length).
        self.vocab_size = self.length

        # Per-class distributions p_k and the mixture p = sum_k w_k * p_k.
        self.v_probs = []
        self.full_vector = torch.zeros_like(vectors[0]).float()
        for w, vector in zip(self.weights, vectors):
            probs = vector / vector.sum()
            self.v_probs.append(probs)
            self.full_vector += w * probs

        self.cond_dim = len(self.v_probs)

        # All position indices as a column, shape (length, 1).
        self.indices = torch.arange(self.length).unsqueeze(1)

    def __len__(self):
        # Nominal length only; sampling is unbounded.
        return 1000

    def generate_cond(self, n_samples):
        """Draw ``n_samples`` class labels from ``weights``, shape ``(n_samples, 1)``."""
        return torch.multinomial(self.weights, n_samples, replacement=True).reshape(-1, 1)

    def sample_from_condition(self, cond):
        """Draw one position for each class label in ``cond``.

        Args:
            cond: Integer tensor of class labels, shape ``(n, 1)`` or ``(n,)``.

        Returns:
            int64 tensor of shape ``(n, 1)`` holding the sampled positions.
        """
        n_samples = cond.shape[0]
        labels = cond.flatten()
        samples = torch.zeros((n_samples, 1), dtype=torch.int64)

        # One multinomial call per class covers all rows carrying that label.
        for class_idx in range(self.n_classes):
            mask = labels == class_idx
            count = mask.sum().item()
            if count == 0:
                continue
            positions = torch.multinomial(self.v_probs[class_idx], count, replacement=True)
            samples[mask] = positions.unsqueeze(1)

        return samples

    def sample(self, n_samples=None):
        """Return ``(samples, cond)`` for ``n_samples`` draws (default ``batch_size``)."""
        n = self.batch_size if n_samples is None else n_samples
        cond = self.generate_cond(n)
        samples = self.sample_from_condition(cond)
        return samples, cond

    def __iter__(self):
        return self

    def __next__(self):
        # Never raises StopIteration: iteration is unbounded.
        samples, cond = self.sample()
        return samples, cond

    def safe_exp(self, p, w):
        """Return ``p ** w`` elementwise, with 0 wherever ``p <= 0``."""
        return torch.where(p > 0., p**w, 0.)

    def get_guided_distribution(self, class_idx, w):
        """Return the distribution proportional to ``p_i**w * p**(1 - w)``.

        ``p_i`` is the distribution of ``class_idx`` and ``p`` the full
        mixture. The product is evaluated in log space and normalised with a
        max-shifted softmax, so large ``|w|`` cannot overflow ``exp``.
        Positions where the exact product is zero (``p_i == 0`` with
        ``w > 0``, or ``p == 0`` with ``w < 1``) receive exactly zero mass
        rather than the small leak a clamped logarithm alone would give them.

        Args:
            class_idx: Index of the class to guide towards.
            w: Guidance weight (0 gives the mixture ``p``, 1 gives ``p_i``).

        Returns:
            1-D tensor of length ``length`` summing to 1, or all zeros if no
            position carries non-zero mass.
        """
        p_i = self.v_probs[class_idx]
        p = self.full_vector

        # Clamp only so log() stays finite; exact zeros are restored below.
        log_p_i = torch.log(torch.clamp(p_i, min=1e-10))
        log_p = torch.log(torch.clamp(p, min=1e-10))
        log_tempered = w * log_p_i + (1 - w) * log_p

        # 0**positive == 0, so these factors vanish exactly.
        zero_mass = torch.zeros_like(p, dtype=torch.bool)
        if w > 0:
            zero_mass |= p_i <= 0
        if w < 1:
            zero_mass |= p <= 0
        if bool(zero_mass.all()):
            return torch.zeros_like(log_tempered)
        log_tempered = log_tempered.masked_fill(zero_mass, float('-inf'))

        # softmax(x) == exp(x) / exp(x).sum(), computed after subtracting max(x).
        return torch.softmax(log_tempered, dim=-1)

    def plot_samples(self, samples, path=None, plot_vectors=True, fig=None, ax=None):
        """Plot the empirical distribution of ``samples``.

        Args:
            samples: Tensor of sampled positions (any shape; it is flattened).
            path: If given, save the figure there and close it; otherwise call
                ``plt.show()``.
            plot_vectors: If True, draw one panel per class distribution, the
                mixture and the empirical histogram in a new figure. Otherwise
                draw only the empirical histogram.
            fig: Optional target figure when ``plot_vectors`` is False.
            ax: Optional target axes when ``plot_vectors`` is False. If given
                without ``fig``, its parent figure is used; if None, a new
                figure is created.
        """
        samples = samples.detach().cpu().numpy().flatten()
        positions = np.arange(self.length)
        # Unit-width bins over [0, length): the density equals the fraction of
        # samples at each position.
        hist, _ = np.histogram(
            samples,
            bins=self.length,
            range=(0, self.length),
            density=True
        )

        if plot_vectors:
            n_cols = self.n_classes + 2
            fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4))

            # Plot the normalised class distributions so every panel shares the
            # probability scale of 'Full Prob' and 'Empirical'.
            to_plot = [*self.v_probs, self.full_vector, hist]
            names = [*(f'Class {i}' for i in range(self.n_classes)), 'Full Prob', 'Empirical']

            for axis, vec, name in zip(axes, to_plot, names):
                vec_np = vec.detach().cpu().numpy() if torch.is_tensor(vec) else vec
                axis.bar(positions, vec_np)
                axis.set_title(name)
                axis.set_xlabel('Position')
                axis.set_ylabel('Probability')
                axis.set_xlim(-0.5, self.length - 0.5)
        else:
            if ax is None:
                fig, ax = plt.subplots(figsize=(8, 4))
            elif fig is None:
                fig = ax.figure

            ax.bar(positions, hist)
            ax.set_title('Samples Distribution')
            ax.set_xlabel('Position')
            ax.set_ylabel('Frequency')
            ax.set_xlim(-0.5, self.length - 0.5)

        # Lay out the figure actually drawn on, not pyplot's "current" figure.
        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
            plt.close(fig)
        else:
            plt.show()


class MatrixInputMixture(DiscreteToyDataset):
    """Mixture of categorical distributions over the cells of a 2-D matrix.

    A draw picks a class label from ``weights`` (the condition) and then one
    ``(i, j)`` cell from that class's normalised matrix.
    """

    def __init__(self, batch_size, matrices, weights=None):
        """
        Initialize a dataset that samples from a mixture of probability matrices.

        Args:
            batch_size: Number of samples per batch
            matrices: A list of 2D tensors of non-negative values, each with
                shape (height, width); each is normalised by its own sum.
            weights: Optional list of weights for each matrix. If None, equal weights are used.

        Raises:
            ValueError: If ``matrices`` is empty, the matrices differ in shape,
                or the number of weights differs from the number of matrices.
        """
        self.batch_size = batch_size
        self.context_len = 2
        self.matrices = matrices
        self.n_classes = len(matrices)
        if self.n_classes == 0:
            raise ValueError("At least one matrix is required")

        # Set up weights for the mixture (uniform if none are provided).
        if weights is None:
            self.weights = torch.ones(self.n_classes) / self.n_classes
        else:
            weights_tensor = torch.tensor(weights, dtype=torch.float)
            self.weights = weights_tensor / weights_tensor.sum()
        # The zip() below would silently truncate on a mismatch.
        if self.weights.numel() != self.n_classes:
            raise ValueError("Number of weights must match the number of matrices")

        shapes = [m.shape for m in matrices]
        if len(set(shapes)) > 1:
            raise ValueError("All matrices must have the same shape")

        self.height, self.width = shapes[0]
        # Column 0 of a sample lies in [0, height), column 1 in [0, width);
        # one vocabulary must cover both.
        self.vocab_size = max(self.height, self.width)

        # Per-class distributions and the mixture sum_k w_k * p_k. Starting
        # from the scalar 0 lets the first addition pick the result dtype.
        self.m_probs = []
        self.flat_probs = []
        self.full_matrix = 0
        for w, matrix in zip(self.weights, matrices):
            probs = matrix / matrix.sum()
            self.m_probs.append(probs)
            self.flat_probs.append(probs.flatten())
            self.full_matrix += w * probs

        self.cond_dim = len(self.m_probs)

        # Row k holds the (i, j) coordinates of flat index k (row-major order,
        # matching probs.flatten()).
        i_indices, j_indices = torch.meshgrid(
            torch.arange(self.height),
            torch.arange(self.width),
            indexing='ij'
        )
        self.indices = torch.stack([i_indices.flatten(), j_indices.flatten()], dim=1)

    def __len__(self):
        # Nominal length only; sampling is unbounded.
        return 1000

    def generate_cond(self, n_samples):
        """Draw ``n_samples`` class labels from ``weights``, shape ``(n_samples, 1)``."""
        return torch.multinomial(self.weights, n_samples, replacement=True).reshape(-1, 1)

    def sample_from_condition(self, cond):
        """Draw one ``(i, j)`` cell for each class label in ``cond``.

        Args:
            cond: Integer tensor of class labels, shape ``(n, 1)`` or ``(n,)``.

        Returns:
            int64 tensor of shape ``(n, 2)`` with ``(i, j)`` rows.
        """
        n_samples = cond.shape[0]
        labels = cond.flatten()
        samples = torch.zeros((n_samples, 2), dtype=torch.int64)

        for class_idx in range(self.n_classes):
            mask = labels == class_idx
            count = mask.sum().item()
            if count == 0:
                continue
            # Sample flat cell indices, then map them to (i, j) coordinates.
            flat_indices = torch.multinomial(
                self.flat_probs[class_idx],
                count,
                replacement=True
            )
            samples[mask] = self.indices[flat_indices]

        return samples

    def sample(self, n_samples=None):
        """Return ``(samples, cond)`` for ``n_samples`` draws (default ``batch_size``)."""
        n = self.batch_size if n_samples is None else n_samples
        cond = self.generate_cond(n)
        samples = self.sample_from_condition(cond)
        return samples, cond

    def __iter__(self):
        return self

    def __next__(self):
        # Never raises StopIteration: iteration is unbounded.
        samples, cond = self.sample()
        return samples, cond

    def safe_exp(self, p, w):
        """Return ``p ** w`` elementwise, with 0 wherever ``p <= 0``."""
        return torch.where(p > 0., p**w, 0.)

    def get_guided_distribution(self, class_idx, w):
        """Return a marginal-corrected guided matrix for ``class_idx``.

        Starts from the normalised product ``p_i**w * p**(1 - w)`` (``p_i`` the
        class matrix, ``p`` the mixture) and multiplies it elementwise by a
        coefficient built from per-row and per-column ratios. Each ratio
        compares the same formula applied to the row/column marginals of
        ``p_i`` and ``p`` against the marginals of the product. The result is
        not re-normalised here.

        Args:
            class_idx: Index of the class matrix to guide towards.
            w: Guidance weight.

        Returns:
            Tensor of shape ``(height, width)``.
        """
        p_i = self.m_probs[class_idx]
        p = self.full_matrix
        m_iw = p_i**w
        m_w = self.safe_exp(p, 1-w)
        tempered = m_iw * m_w
        tempered = tempered/tempered.sum()

        # Row (indexed by i) and column (indexed by j) marginals of the product.
        dim_1 = tempered.sum(-1)
        dim_0 = tempered.sum(-2)
        # Guidance formula applied directly to the row / column marginals.
        ddd1 = self.safe_exp(p.sum(-1), (1-w)) * p_i.sum(-1)**w
        ddd0 = self.safe_exp(p.sum(-2), (1-w)) * p_i.sum(-2)**w

        # Per-column (j) and per-row (i) ratios; 1 where the marginal is 0.
        # tempered was normalised above, so tempered.sum() is ~1 here.
        inv_ci = torch.where(dim_0 != 0, ddd0 / dim_0, 1.)
        inv_cN = (ddd1).sum() / tempered.sum()
        inv_dj = torch.where(dim_1 != 0, ddd1 / dim_1, 1.)
        inv_dN = (ddd0).sum() / tempered.sum()

        den = inv_cN + inv_dN

        # coeff_M[i, j] = (inv_ci[j] + inv_dj[i]) / den
        coeff_M = inv_ci.unsqueeze(0) + inv_dj.unsqueeze(1)
        coeff_M /= den

        return tempered * coeff_M

    def plot_matrix_with_annotations(self, ax, matrix, title, xlabel='j', ylabel='i'):
        """
        Helper method to plot a probability matrix with text annotations.

        Args:
            ax: The matplotlib axis to plot on
            matrix: The matrix to plot (normalised to sum to 1 before display)
            title: Title for the plot
            xlabel: Label for x-axis
            ylabel: Label for y-axis

        Returns:
            The ``AxesImage`` created by ``ax.imshow``.
        """
        # Normalize to get true probabilities if not already normalized
        matrix_prob = matrix / matrix.sum() if matrix.sum() > 0 else matrix

        im = ax.imshow(matrix_prob, cmap='viridis', origin='lower')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        # Annotate only small matrices; larger ones become unreadable.
        max_size_for_annotations = 20
        if matrix.shape[0] <= max_size_for_annotations and matrix.shape[1] <= max_size_for_annotations:
            for y in range(matrix_prob.shape[0]):
                for x in range(matrix_prob.shape[1]):
                    value = matrix_prob[y, x]
                    # Only label cells with non-zero probability.
                    if value > 0:
                        # Use more decimals for values below 0.01.
                        if value >= 0.01:
                            text = f'{value:.2f}'
                        else:
                            text = f'{value:.4f}'
                        # Fixed text colour, placed at the cell centre.
                        text_color = 'black'
                        ax.text(x, y, text, ha='center', va='center',
                                color=text_color, fontsize=8)

        return im

    def plot_samples(self, samples, path=None, plot_matrices=True, fig=None, ax=None):
        """Plot the empirical 2-D distribution of ``samples``.

        Args:
            samples: Tensor of ``(i, j)`` rows, shape ``(n, 2)``.
            path: If given, save the figure there and close it; otherwise call
                ``plt.show()``.
            plot_matrices: If True, draw every class matrix, the mixture and
                the empirical histogram in a new figure. Otherwise draw only
                the empirical histogram.
            fig: Optional target figure when ``plot_matrices`` is False.
            ax: Optional target axes when ``plot_matrices`` is False. If given
                without ``fig``, its parent figure is used; if None, a new
                figure is created.
        """
        samples = samples.detach().cpu().numpy()

        # histogram2d(x=j, y=i) has shape (width, height); transpose it to
        # (height, width) so it lines up with the probability matrices.
        hist, _, _ = np.histogram2d(
            samples[:, 1],
            samples[:, 0],
            bins=[self.width, self.height],
            range=[[0, self.width], [0, self.height]],
            density=True
        )
        empirical = hist.T

        if plot_matrices:
            n_cols = self.n_classes + 2
            fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))

            to_plot = [*self.matrices, self.full_matrix, empirical]
            names = [*(f'Class {i}' for i in range(self.n_classes)), 'Full Prob', 'Empirical']

            for axis, mat, name in zip(axes, to_plot, names):
                mat_np = mat.detach().cpu().numpy() if torch.is_tensor(mat) else mat
                self.plot_matrix_with_annotations(axis, mat_np, name)
        else:
            if ax is None:
                fig, ax = plt.subplots(figsize=(8, 8))
            elif fig is None:
                fig = ax.figure

            self.plot_matrix_with_annotations(ax, empirical, 'Samples Distribution')

        # Lay out the figure actually drawn on, not pyplot's "current" figure.
        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
            plt.close(fig)
        else:
            plt.show()


class GaussianMixture5D(DiscreteToyDataset):
    """Mixture of 5-D Gaussians discretised onto a per-dimension grid.

    The condition is the component index. Each sample is drawn from the
    continuous component, and every coordinate is snapped to the nearest of
    ``grid_size`` evenly spaced points on ``[-3, 3]``; values outside that
    range snap to the end points. Samples hold grid *indices* in
    ``[0, grid_size)``.

    Args:
        batch_size: Number of samples produced by ``sample()`` / ``next()``.
        means: Per-component means, each of length 5 (lists or tensors).
        covariances: Per-component 5x5 covariance matrices (lists or tensors).
        weights: Optional mixture weights, normalised to sum to 1; uniform
            when ``None``.
        grid_size: Number of grid points per dimension.

    Raises:
        ValueError: If there are no components, the numbers of means,
            covariances and weights disagree, or a mean is not 5-dimensional.
    """

    def __init__(self, batch_size, means, covariances, weights=None, grid_size=10):
        self.batch_size = batch_size
        self.context_len = 5
        self.vocab_size = grid_size
        self.grid_size = grid_size

        # as_tensor(...).clone() accepts lists and tensors alike. torch.tensor()
        # on a tensor emits a copy-construct UserWarning. clone() ensures the
        # caller's data is never aliased.
        self.means = [torch.as_tensor(mean, dtype=torch.float).clone() for mean in means]
        self.covariances = [torch.as_tensor(cov, dtype=torch.float).clone() for cov in covariances]
        self.n_components = len(means)
        if self.n_components == 0:
            raise ValueError("At least one component is required")
        if len(self.covariances) != self.n_components:
            raise ValueError("Number of covariances must match the number of means")
        if any(mean.shape != (self.context_len,) for mean in self.means):
            raise ValueError("Each mean must be a vector of length 5")

        # Mixture weights (uniform if none are provided).
        if weights is None:
            self.weights = torch.ones(self.n_components) / self.n_components
        else:
            weights_tensor = torch.tensor(weights, dtype=torch.float)
            self.weights = weights_tensor / weights_tensor.sum()
        if self.weights.numel() != self.n_components:
            raise ValueError("Number of weights must match the number of means")

        # Grid points shared by every dimension.
        self.grid_points = torch.linspace(-3, 3, grid_size)

        self.precompute_densities()

    def precompute_densities(self):
        """Build one ``MultivariateNormal`` per component and a mixture-pdf callable."""
        from torch.distributions.multivariate_normal import MultivariateNormal

        self.distributions = [
            MultivariateNormal(mean, covariance)
            for mean, covariance in zip(self.means, self.covariances)
        ]

        # Continuous mixture density sum_k w_k * N_k(x), kept for reference.
        self.analytical_pdf = lambda x: sum(
            weight * dist.log_prob(x).exp()
            for weight, dist in zip(self.weights, self.distributions)
        )

    def generate_cond(self, n_samples):
        """Draw ``n_samples`` component indices, shape ``(n_samples, 1)``."""
        return torch.multinomial(self.weights, n_samples, replacement=True).reshape(-1, 1)

    def sample_from_condition(self, cond):
        """Sample from each row's component and snap every coordinate to the grid.

        Args:
            cond: Integer tensor of component indices, shape ``(n, 1)`` or ``(n,)``.

        Returns:
            int64 tensor of shape ``(n, 5)`` holding grid indices.
        """
        n_samples = cond.shape[0]
        labels = cond.flatten()
        samples = torch.zeros((n_samples, self.context_len), dtype=torch.int64)

        for comp_idx in range(self.n_components):
            mask = labels == comp_idx
            count = mask.sum().item()
            if count == 0:
                continue
            continuous = self.distributions[comp_idx].sample((count,))
            # (count, 5, 1) - (grid_size,) -> (count, 5, grid_size) distances;
            # argmin over the last axis picks the nearest grid point per coordinate.
            distances = torch.abs(continuous.unsqueeze(-1) - self.grid_points)
            samples[mask] = distances.argmin(dim=-1)

        return samples

    def get_guided_distribution(self, class_idx, w):
        """Return the guided grid distribution and the mixture density values.

        The guided distribution is proportional to ``p_i(x)**w * p(x)**(1 - w)``
        at every point of the 5-D grid, where ``p_i`` is component
        ``class_idx`` and ``p`` the full mixture. It is computed in log space
        and normalised with a max-shifted softmax, so large ``|w|`` cannot
        overflow or underflow ``exp`` into NaNs.

        Args:
            class_idx: Index of the component to guide towards.
            w: Guidance weight (0 = full mixture, 1 = only the component).

        Returns:
            Tuple ``(guided, mixture)``. ``guided`` has shape
            ``(grid_size,) * 5`` and sums to 1. ``mixture`` is a flat tensor
            of length ``grid_size ** 5`` with the mixture density ``p(x)`` at
            the same grid points in row-major grid order; it is not normalised
            over the grid.
        """
        # Every grid point as a row of 5 coordinates, in row-major order
        # (matching the view() at the end).
        coords = torch.meshgrid(*([self.grid_points] * self.context_len), indexing='ij')
        positions = torch.stack([coord.flatten() for coord in coords], dim=1)

        # Component log-density log p_i(x).
        log_p_i = self.distributions[class_idx].log_prob(positions)

        # Mixture log-density log p(x) = logsumexp_j(log w_j + log p_j(x)).
        log_p_mix = torch.logsumexp(torch.stack([
            torch.log(self.weights[j]) + self.distributions[j].log_prob(positions)
            for j in range(self.n_components)
        ], dim=0), dim=0)

        log_guided = w * log_p_i + (1 - w) * log_p_mix

        # softmax(x) == exp(x) / exp(x).sum(), computed after subtracting max(x).
        guided_density = torch.softmax(log_guided, dim=0)
        guided_density = guided_density.view([self.grid_size] * self.context_len)

        return guided_density, log_p_mix.exp()

    def sample(self, n_samples=None):
        """Return ``(samples, cond)`` for ``n_samples`` draws (default ``batch_size``)."""
        n = self.batch_size if n_samples is None else n_samples
        cond = self.generate_cond(n)
        samples = self.sample_from_condition(cond)
        return samples, cond

    def __iter__(self):
        return self

    def __next__(self):
        # Never raises StopIteration: iteration is unbounded.
        samples, cond = self.sample()
        return samples, cond

    def __len__(self):
        # Nominal length only; sampling is unbounded.
        return 1000

    def plot_samples(self, samples, path=None):
        """Plot a 2-D histogram for each of the 10 pairs of dimensions.

        Grid indices in ``samples`` are mapped back to their grid coordinates
        before histogramming over ``[-3, 3]``.

        Args:
            samples: Integer tensor of grid indices, shape ``(n, 5)``.
            path: If given, save the figure there and close it; otherwise call
                ``plt.show()``.
        """
        samples = samples.detach().cpu().numpy()
        grid = self.grid_points.numpy()

        fig, axes = plt.subplots(2, 5, figsize=(15, 6))
        dim_pairs = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
                     (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]

        for ax, (dim1, dim2) in zip(axes.flatten(), dim_pairs):
            x = grid[samples[:, dim1]]
            y = grid[samples[:, dim2]]

            h, _, _ = np.histogram2d(
                x, y, bins=self.grid_size,
                range=[[-3, 3], [-3, 3]], density=True
            )

            # h is indexed [x_bin, y_bin]; transpose so dim2 runs vertically.
            ax.imshow(h.T, origin='lower', extent=[-3, 3, -3, 3],
                      aspect='auto', cmap='viridis')

            ax.set_title(f'Dims {dim1+1} vs {dim2+1}')
            ax.set_xlabel(f'Dimension {dim1+1}')
            ax.set_ylabel(f'Dimension {dim2+1}')

        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
            plt.close(fig)
        else:
            plt.show()

            

def _peaked_cluster_kernel():
    """Return the 9x9 bump (peak 1.0 at the centre) used by the matrix datasets."""
    return torch.tensor([
        [0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.1, 0.2, 0.1, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.1, 0.3, 0.5, 0.3, 0.1, 0.0, 0.0],
        [0.0, 0.1, 0.3, 0.6, 0.8, 0.6, 0.3, 0.1, 0.0],
        [0.1, 0.2, 0.5, 0.8, 1.0, 0.8, 0.5, 0.2, 0.1],
        [0.0, 0.1, 0.3, 0.6, 0.8, 0.6, 0.3, 0.1, 0.0],
        [0.0, 0.0, 0.1, 0.3, 0.5, 0.3, 0.1, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.1, 0.2, 0.1, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0],
    ])


def get_dataset(name, batch_size=6):
    """Build one of the predefined toy datasets by name.

    Args:
        name: One of ``'vector-disjoint'``, ``'vector-intersection'``,
            ``'matrix-disjoint'``, ``'matrix-intersection'`` or ``'gaussian-5d'``.
        batch_size: Batch size passed to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``name`` is not one of the names above.
    """
    if name == 'vector-disjoint':
        vectors = torch.tensor([
            [0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # Class 0: two bumps on the left
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0],  # Class 1: two bumps on the right
        ])
        return VectorInputMixture(batch_size=batch_size,
                                  vectors=vectors,
                                  weights=[.5, .5])

    if name == 'vector-intersection':
        vectors = torch.tensor([
            [0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # Class 0: two bumps on the left
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # Class 1: first bump overlaps class 0's second bump
        ])
        return VectorInputMixture(batch_size=batch_size,
                                  vectors=vectors,
                                  weights=[.5, .5])

    if name == 'matrix-disjoint':
        height, width = 30, 30
        cluster = _peaked_cluster_kernel()

        matrix1 = torch.zeros((height, width))
        matrix1[7:16, 7:16] = cluster
        # The two placements overlap; take the max so neither bump is cut off.
        matrix1[1:10, 1:10] = torch.maximum(matrix1[1:10, 1:10], cluster)

        matrix2 = torch.zeros((height, width))
        matrix2[13:22, 13:22] = cluster
        matrix2[19:28, 19:28] = torch.maximum(matrix2[19:28, 19:28], cluster)

        weights = [0.5, 0.5]
        return MatrixInputMixture(batch_size=batch_size, matrices=[matrix1, matrix2], weights=weights)

    if name == 'matrix-intersection':
        height, width = 30, 30
        cluster = _peaked_cluster_kernel()

        matrix1 = torch.zeros((height, width))
        matrix1[1:10, 1:10] = cluster
        matrix1[9:18, 9:18] = cluster

        matrix2 = torch.zeros((height, width))
        matrix2[11:20, 11:20] = cluster
        matrix2[19:28, 19:28] = cluster

        weights = [0.5, 0.5]
        return MatrixInputMixture(batch_size=batch_size, matrices=[matrix1, matrix2], weights=weights)

    if name == 'gaussian-5d':
        # Three 5-D components; components 1 and 2 share the centre's mean on
        # dim 2 (0-based) and are offset in opposite directions elsewhere.
        means = [
            [0.0, 0.0, 0.0, 0.0, 0.0],           # Center
            [2.0, 2.0, 0.0, -1.0, -1.0],
            [-2.0, -2.0, 0.0, 1.0, 1.0],
        ]

        # Diagonal covariances: diag(2, 2, 1.25, 1.25, 1.25) for the centre and
        # diag(1, 1, 0.625, 0.625, 0.625) for the two outer components.
        covariances = [
            torch.diag(torch.tensor([0.8, 0.8, 0.5, 0.5, 0.5])) / .4,  # Center cluster
            torch.diag(torch.tensor([0.8, 0.8, 0.5, 0.5, 0.5])) / .8,  # Second cluster
            torch.diag(torch.tensor([0.8, 0.8, 0.5, 0.5, 0.5])) / .8,  # Third cluster
        ]

        weights = [0.4, 0.3, 0.3]

        return GaussianMixture5D(
            batch_size=batch_size,
            means=means,
            covariances=covariances,
            weights=weights,
            grid_size=10  # Discretize each dimension into 10 bins
        )

    # Previously this printed a message and returned None, which only failed
    # later with an unrelated AttributeError.
    raise ValueError(f"Dataset '{name}' is not implemented")
